from dataclasses import dataclass

@dataclass
class ModelArgs:
    """Bundle of default hyperparameters and runtime settings for a model run.

    Every field has a default, so ``ModelArgs()`` yields a complete
    configuration; override individual values by keyword. The field order
    below is part of the public interface: it fixes the positional order of
    the generated ``__init__`` and the field order in ``__repr__``. Do not
    reorder fields without checking the callers.
    """

    # --- layer sizes (exact usage lives in the model code; not visible here) ---
    embedding_dim: int = 256
    hidden_dim: int = 64
    num_classes: int = 2

    # --- run schedule ---
    n_splits: int = 5  # presumably the number of cross-validation folds -- confirm
    epochs: int = 1
    batch_size: int = 128

    # --- architecture knobs ---
    num_layers: int = 8
    dropout: float = 0.3
    num_heads: int = 8
    filter_nums: int = 128
    out_channels: int = 512
    filter_sizes: tuple = (32, 64, 128)  # tuple, so the shared default is immutable

    # --- device / optimisation ---
    gpu_id: int = 2  # NOTE(review): looks like a device index -- confirm against caller
    lr: float = 1e-4
    # Allowed values per the original note: "flash_attention_2", "eager", "sdpa".
    # Not validated here.
    attention: str = "eager"